{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.6.8\n",
      "IPython 7.2.0\n",
      "\n",
      "torch 1.1.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using PyTorch Dataset Loading Utilities for Custom Dataset -- Asian Face Dataset (AFAD)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook provides an example for how to prepare a custom dataset for PyTorch's data loading utilities. More in-depth information can be found in the official documentation at:\n",
    "\n",
    "- [Data Loading and Processing Tutorial](http://pytorch.org/tutorials/beginner/data_loading_tutorial.html)\n",
    "- [torch.utils.data](http://pytorch.org/docs/master/data.html) API documentation\n",
    "\n",
    "In this example, we are using the Asian Face Dataset (AFAD), which is a face image dataset with age labels [1]. There are two versions of this dataset, a smaller Lite version and the full version, which are available at\n",
    "\n",
    "- https://github.com/afad-dataset/tarball-lite\n",
    "- https://github.com/afad-dataset/tarball\n",
    "\n",
    "Here, we will be working with the Lite dataset, but the same code can be used for the full dataset as well -- the Lite \n",
    "dataset is just slightly smaller than the full dataset and thus faster to process.\n",
    "\n",
    "[1] Niu, Z., Zhou, M., Wang, L., Gao, X., & Hua, G. (2016). Ordinal regression with multiple output cnn for age estimation. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 4920-4928)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import SubsetRandomSampler\n",
    "from torch.utils.data import Dataset\n",
    "import torch.nn.functional as F\n",
    "import torch\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Downloading the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following lines of code (bash commands) will download, unzip, and untar the dataset from GitHub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'tarball-lite'...\n",
      "remote: Enumerating objects: 37, done.\u001b[K\n",
      "remote: Total 37 (delta 0), reused 0 (delta 0), pack-reused 37\u001b[K\n",
      "Unpacking objects: 100% (37/37), done.\n",
      "Checking out files: 100% (30/30), done.\n"
     ]
    }
   ],
   "source": [
    "# Download\n",
    "!git clone https://github.com/afad-dataset/tarball-lite.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Join individual tars\n",
    "!cat tarball-lite/AFAD-Lite.tar.xz* > tarball-lite/AFAD-Lite.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Unzip\"\n",
    "!tar xf tarball-lite/AFAD-Lite.tar.xz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get image paths\n",
    "rootDir = 'AFAD-Lite'\n",
    "\n",
    "files = [os.path.relpath(os.path.join(dirpath, file), rootDir)\n",
    "         for (dirpath, dirnames, filenames) in os.walk(rootDir) \n",
    "         for file in filenames if file.endswith('.jpg')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of images in total: 59344\n"
     ]
    }
   ],
   "source": [
    "print(f'Number of images in total: {len(files)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating Label Files (CSVs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "d = {}\n",
    "\n",
    "d['age'] = []\n",
    "d['gender'] = []\n",
    "d['file'] = []\n",
    "d['path'] = []\n",
    "\n",
    "for f in files:\n",
    "    age, gender, fname = f.split('/')\n",
    "    if gender == '111':\n",
    "        gender = 'male'\n",
    "    else:\n",
    "        gender = 'female'\n",
    "        \n",
    "    d['age'].append(age)\n",
    "    d['gender'].append(gender)\n",
    "    d['file'].append(fname)\n",
    "    d['path'].append(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>gender</th>\n",
       "      <th>file</th>\n",
       "      <th>path</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>474596-0.jpg</td>\n",
       "      <td>39/112/474596-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>397477-0.jpg</td>\n",
       "      <td>39/112/397477-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>576466-0.jpg</td>\n",
       "      <td>39/112/576466-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>399405-0.jpg</td>\n",
       "      <td>39/112/399405-0.jpg</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>39</td>\n",
       "      <td>female</td>\n",
       "      <td>410524-0.jpg</td>\n",
       "      <td>39/112/410524-0.jpg</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  age  gender          file                 path\n",
       "0  39  female  474596-0.jpg  39/112/474596-0.jpg\n",
       "1  39  female  397477-0.jpg  39/112/397477-0.jpg\n",
       "2  39  female  576466-0.jpg  39/112/576466-0.jpg\n",
       "3  39  female  399405-0.jpg  39/112/399405-0.jpg\n",
       "4  39  female  410524-0.jpg  39/112/410524-0.jpg"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame.from_dict(d)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Normalize labels such that they start with `0`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'18'"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df['age'].min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "df['age'] = df['age'].values.astype(int) - int(df['age'].min())"
   ]
  },
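  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The path parsing and label shift used above can also be sketched without pandas. This is a minimal illustration, assuming relative paths of the form `age/gender code/file name` and that the code `111` denotes male (as in the loop above). The helper name `parse_afad_path` and the second and third sample paths are hypothetical.\n",
    "\n",
    "```python\n",
    "from pathlib import PurePosixPath\n",
    "\n",
    "\n",
    "def parse_afad_path(rel_path):\n",
    "    # Split 'age/gender_code/filename' into (age, gender, filename)\n",
    "    age, gender_code, fname = PurePosixPath(rel_path).parts\n",
    "    # Assumption taken from the loop above: '111' is male, anything else female\n",
    "    gender = 'male' if gender_code == '111' else 'female'\n",
    "    return int(age), gender, fname\n",
    "\n",
    "\n",
    "# Example paths in the same layout as the dataset (partly made up)\n",
    "paths = ['39/112/474596-0.jpg', '18/111/000001-0.jpg', '25/111/000002-0.jpg']\n",
    "records = [parse_afad_path(p) for p in paths]\n",
    "\n",
    "# Shift the ages so that the smallest label becomes 0\n",
    "min_age = min(age for age, _, _ in records)\n",
    "labels = [age - min_age for age, _, _ in records]\n",
    "print(labels)\n",
    "```\n",
    "\n",
    "Converting the ages to `int` before taking the minimum avoids comparing them as strings."
   ]
  },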
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Seperate dataset into training and test subsets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(123)\n",
    "msk = np.random.rand(len(df)) < 0.8\n",
    "df_train = df[msk]\n",
    "df_test = df[~msk]"
   ]
  },
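  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The random mask above assigns each row to the training set with probability 0.8, so the resulting split is only approximately 80/20. If an exact proportion is preferred, one possible alternative (not what this notebook uses) is to shuffle the row indices and cut them at a fixed position. The sketch below assumes a hypothetical `num_examples` standing in for `len(df)`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical stand-in for len(df)\n",
    "num_examples = 1000\n",
    "\n",
    "rng = np.random.RandomState(123)\n",
    "shuffled = rng.permutation(num_examples)\n",
    "\n",
    "# The first 80% of the shuffled indices form the training set\n",
    "split = int(0.8 * num_examples)\n",
    "train_idx = np.sort(shuffled[:split])\n",
    "test_idx = np.sort(shuffled[split:])\n",
    "\n",
    "print(len(train_idx), len(test_idx))\n",
    "```\n",
    "\n",
    "With a DataFrame, `df.iloc[train_idx]` and `df.iloc[test_idx]` would then select the corresponding rows."
   ]
  },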
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save data partitioning as CSV:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_train.to_csv('training_set_lite.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_test.to_csv('test_set_lite.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of age labels: 22\n"
     ]
    }
   ],
   "source": [
    "num_ages = np.unique(df['age'].values).shape[0]\n",
    "print(f'Number of age labels: {num_ages}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of training examples: 47524\n",
      "Number of test examples: 11820\n"
     ]
    }
   ],
   "source": [
    "print(f'Number of training examples: {df_train.shape[0]}')\n",
    "print(f'Number of test examples: {df_test.shape[0]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Implementing a Custom Dataset Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AFADDatasetAge(Dataset):\n",
    "    \"\"\"Custom Dataset for loading AFAD face images\"\"\"\n",
    "\n",
    "    def __init__(self, csv_path, img_dir, transform=None):\n",
    "\n",
    "        df = pd.read_csv(csv_path)\n",
    "        self.img_dir = img_dir\n",
    "        self.csv_path = csv_path\n",
    "        self.df = df\n",
    "        self.y = df['age'].values\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(os.path.join(self.img_dir,\n",
    "                                      self.df.iloc[index]['path']))\n",
    "\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        label = self.y[index]\n",
    "\n",
    "        return img, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.y.shape[0]"
   ]
  },
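  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a map-style dataset like the one above, `DataLoader` essentially relies on `__getitem__` and `__len__`. The following torch-free sketch illustrates that contract with a hypothetical in-memory class (`ToyDataset`) and a simplified, hypothetical `iterate_batches` helper that mimics sequential batching. It is not how `DataLoader` is implemented.\n",
    "\n",
    "```python\n",
    "class ToyDataset:\n",
    "    # Hypothetical in-memory stand-in for AFADDatasetAge\n",
    "\n",
    "    def __init__(self, features, labels, transform=None):\n",
    "        assert len(features) == len(labels)\n",
    "        self.features = features\n",
    "        self.labels = labels\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        x = self.features[index]\n",
    "        if self.transform is not None:\n",
    "            x = self.transform(x)\n",
    "        return x, self.labels[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "\n",
    "def iterate_batches(dataset, batch_size):\n",
    "    # Simplified sequential batching (no shuffling, no worker processes)\n",
    "    for start in range(0, len(dataset), batch_size):\n",
    "        stop = min(start + batch_size, len(dataset))\n",
    "        items = [dataset[i] for i in range(start, stop)]\n",
    "        xs, ys = zip(*items)\n",
    "        yield list(xs), list(ys)\n",
    "\n",
    "\n",
    "toy = ToyDataset(features=[1, 2, 3, 4, 5], labels=[0, 1, 0, 1, 0],\n",
    "                 transform=lambda x: x * 10)\n",
    "batches = list(iterate_batches(toy, batch_size=2))\n",
    "print(batches)\n",
    "```\n",
    "\n",
    "The last batch is smaller than `batch_size` because five examples do not divide evenly into batches of two."
   ]
  },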
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "TRAIN_CSV_PATH = 'training_set_lite.csv'\n",
    "TEST_CSV_PATH = 'test_set_lite.csv'\n",
    "IMAGE_PATH = 'AFAD-Lite'\n",
    "BATCH_SIZE = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([128, 3, 120, 120])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "test_transform = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                     transforms.CenterCrop((120, 120)),\n",
    "                                     transforms.ToTensor()])\n",
    "\n",
    "test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,\n",
    "                               img_dir=IMAGE_PATH,\n",
    "                               transform=test_transform)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=4,\n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in test_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_indices = torch.arange(0, 46000).numpy()\n",
    "valid_indices = torch.arange(46000, 47524).numpy()\n",
    "\n",
    "train_sampler = SubsetRandomSampler(train_indices)\n",
    "valid_sampler = SubsetRandomSampler(valid_indices)\n",
    "\n",
    "\n",
    "\n",
    "train_transform = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                      transforms.RandomCrop((120, 120)),\n",
    "                                      transforms.ToTensor()])\n",
    "\n",
    "test_transform = transforms.Compose([transforms.Resize((128, 128)),\n",
    "                                     transforms.CenterCrop((120, 120)),\n",
    "                                     transforms.ToTensor()])\n",
    "\n",
    "\n",
    "train_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,\n",
    "                               img_dir=IMAGE_PATH,\n",
    "                               transform=train_transform)\n",
    "\n",
    "valid_dataset = AFADDatasetAge(csv_path=TRAIN_CSV_PATH,\n",
    "                               img_dir=IMAGE_PATH,\n",
    "                               transform=test_transform)\n",
    "\n",
    "test_dataset = AFADDatasetAge(csv_path=TEST_CSV_PATH,\n",
    "                               img_dir=IMAGE_PATH,\n",
    "                               transform=test_transform)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "train_loader = DataLoader(train_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=4,\n",
    "                          sampler=train_sampler)\n",
    "\n",
    "valid_loader = DataLoader(valid_dataset,\n",
    "                          batch_size=BATCH_SIZE,\n",
    "                          num_workers=4,\n",
    "                          sampler=valid_sampler)\n",
    "\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=4,\n",
    "                         shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Image batch dimensions: torch.Size([128, 3, 120, 120])\n",
      "Image label dimensions: torch.Size([128])\n",
      "Image batch dimensions: torch.Size([128, 3, 120, 120])\n",
      "Image label dimensions: torch.Size([128])\n",
      "Image batch dimensions: torch.Size([128, 3, 120, 120])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "# Checking the dataset\n",
    "for images, labels in test_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "    \n",
    "for images, labels in valid_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "    \n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Iterating through the Custom Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "torch.manual_seed(0)\n",
    "\n",
    "num_epochs = 2\n",
    "for epoch in range(num_epochs):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
